# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from itertools import combinations_with_replacement
from typing import Union, Sequence, List, Optional

import jax.numpy as jnp
import numpy as np

import brainpy.math as bm
from brainpy import check
from brainpy.dnn.base import Layer

# Names exported by ``from <this module> import *``.
__all__ = [
    'NVAR'
]


def _comb(N, k):
    r"""Return the number of combinations of N things taken k at a time.

    .. math::

       \frac{N!}{(N-k)! k!}

    Parameters::

    N: int
      The number of things.
    k: int
      The number of elements taken.

    Returns::

    int
      The exact binomial coefficient. ``0`` is returned when ``k > N`` or
      when either argument is negative, since no combination exists then.

    """
    # Negative arguments used to slip into the main branch with an empty loop
    # and yield 1; there are no combinations in that case.
    if N < 0 or k < 0:
        return 0
    if k > N:
        return 0
    # Use the symmetry C(N, k) == C(N, N - k) to keep the loop short. The
    # running product is always divisible by (j + 1), so ``//`` stays exact.
    val = 1
    for j in range(min(k, N - k)):
        val = (val * (N - j)) // (j + 1)
    return val


class NVAR(Layer):
    """Nonlinear vector auto-regression (NVAR) node.

    On every call the node writes the current input into a ring buffer,
    gathers ``delay`` stored inputs spaced ``stride`` steps apart (the linear
    part) and, for every requested order, appends the products of all
    combinations with replacement of the linear components (the nonlinear
    part).

    This class has the following features:

    - it supports batch size,
    - it supports multiple orders,

    Parameters::

    num_in: int
      The number of input features.
    delay: int
      The number of delay step.
    order: optional, int, sequence of int
      The nonlinear order(s), each an integer >= 2. ``None`` (the default)
      produces no nonlinear features.
    stride: int
      The stride to sample linear part vector in the delays.
    constant: bool
      Whether to prepend a constant feature equal to 1 to the output.
    mode: optional, Mode
      The computing mode, forwarded to the base layer. With a batching mode
      the stored inputs carry a batch axis.
    name: optional, str
      The name of the node, forwarded to the base layer.

    References::

    .. [1] Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation
           reservoir computing. Nat Commun 12, 5564 (2021).
           https://doi.org/10.1038/s41467-021-25801-2

    """

    def __init__(
        self,
        num_in: int,
        delay: int,
        order: Optional[Union[int, Sequence[int]]] = None,
        stride: int = 1,
        constant: bool = False,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super(NVAR, self).__init__(mode=mode, name=name)

        # parameters
        # Normalize ``order`` to a tuple so that a scalar, a sequence and
        # ``None`` are all handled by the same code below.
        order = tuple() if order is None else order
        if not isinstance(order, (tuple, list)):
            order = (order,)
        self.order = tuple(order)
        check.is_sequence(order, 'order', allow_none=False)
        for o in order:
            check.is_integer(o, 'order', allow_none=False, min_bound=2)
        check.is_integer(delay, 'delay', allow_none=False, min_bound=1)
        check.is_integer(stride, 'stride', allow_none=False, min_bound=1)
        assert isinstance(constant, bool), f'Must be an instance of boolean, but got {constant}.'
        self.delay = delay
        self.stride = stride
        self.constant = constant
        # Length of the ring buffer: enough steps to reach back
        # ``(delay - 1) * stride`` steps from the current input.
        self.num_delay = 1 + (self.delay - 1) * self.stride
        self.num_in = num_in

        # delay variables
        # ``idx`` is the position in ``store`` that the next input is written to.
        self.idx = bm.Variable(jnp.asarray([0]))
        if isinstance(self.mode, bm.BatchingMode):
            batch_size = 1  # first initialize the state with batch size = 1
            self.store = bm.Variable(jnp.zeros((self.num_delay, batch_size, self.num_in)), batch_axis=1)
        else:
            self.store = bm.Variable(jnp.zeros((self.num_delay, self.num_in)))

        # linear dimension
        self.linear_dim = self.delay * num_in
        # For each monomial created in the non-linear part, indices
        # of the n components involved, n being the order of the
        # monomials. Precompute them to improve efficiency.
        self.comb_ids = []
        for n_order in self.order:
            assert n_order >= 2, f'"order" must be a integer >= 2, while we got {n_order}.'
            idx = np.array(list(combinations_with_replacement(np.arange(self.linear_dim), n_order)))
            self.comb_ids.append(jnp.asarray(idx))
        # number of non-linear components is (d + n - 1)! / (d - 1)! n!
        # i.e. number of all unique monomials of order n made from the
        # linear components.
        self.nonlinear_dim = sum(len(ids) for ids in self.comb_ids)
        # output dimension
        self.num_out = int(self.linear_dim + self.nonlinear_dim)
        if self.constant:
            self.num_out += 1

    def reset_state(self, batch_or_mode=None, **kwargs):
        """Reset the node state which depends on batch size.

        Parameters::

        batch_or_mode: optional, int, Mode
          ``None`` or a non-batching mode creates a state without batch axis.
          An integer is used as the batch size. A batching mode contributes
          its ``batch_size``.
        """
        self.idx[0] = 0
        # Resolve a mode object to a plain batch size; previously a mode was
        # put directly into the shape tuple, which cannot work.
        if isinstance(batch_or_mode, bm.Mode):
            if isinstance(batch_or_mode, bm.BatchingMode):
                batch_size = batch_or_mode.batch_size
            else:
                batch_size = None
        else:
            batch_size = batch_or_mode
        # To store the last inputs.
        # Note, the batch axis is not in the first dimension, so we
        # manually handle the state of NVAR, rather return it.
        if batch_size is None:
            self.store.value = jnp.zeros((self.num_delay, self.num_in))
        else:
            self.store.value = jnp.zeros((self.num_delay, batch_size, self.num_in))

    def update(self, x):
        """Store ``x`` and return the NVAR feature vector for this step.

        Parameters::

        x: array
          The current input. It is written into one slot of the buffer, so it
          is expected to match ``(num_in,)`` without batching and
          ``(num_batch, num_in)`` with a batching mode.

        Returns::

        array
          The concatenation, along the last axis, of the optional constant,
          the linear part and one nonlinear part per order.
        """
        all_parts = []
        # Buffer positions of the current input and of the delayed inputs,
        # walking backwards in steps of ``stride`` with wrap-around.
        select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay
        # 1. Store the current input
        self.store[self.idx[0]] = x

        if isinstance(self.mode, bm.BatchingMode):
            # 2. Linear part:
            # select all previous inputs, including the current, with strides
            linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1)  # (num_batch, num_time, num_feature)
            linear_parts = jnp.reshape(linear_parts, (linear_parts.shape[0], -1))
            # 3. constant
            if self.constant:
                constant = jnp.ones((linear_parts.shape[0], 1), dtype=x.dtype)
                all_parts.append(constant)
            all_parts.append(linear_parts)
            # 4. Nonlinear part:
            # select monomial terms and compute them
            for ids in self.comb_ids:
                all_parts.append(jnp.prod(linear_parts[:, ids], axis=2))

        else:
            # 2. Linear part:
            # select all previous inputs, including the current, with strides
            linear_parts = self.store[select_ids].flatten()  # (num_time x num_feature,)
            # 3. constant
            if self.constant:
                constant = jnp.ones((1,), dtype=x.dtype)
                all_parts.append(constant)
            all_parts.append(linear_parts)
            # 4. Nonlinear part:
            # select monomial terms and compute them
            for ids in self.comb_ids:
                all_parts.append(jnp.prod(linear_parts[ids], axis=1))

        # 5. Finally, advance the write position of the ring buffer
        self.idx.value = (self.idx + 1) % self.num_delay
        return jnp.concatenate(all_parts, axis=-1)

    def get_feature_names(self, for_plot=False) -> List[str]:
        """Get output feature names for transformation.

        Parameters::

        for_plot: bool
          Use the feature names for plotting or not? (Default False)
          When True, the linear and nonlinear names are wrapped in ``$``.

        Returns::

        list of str
          The names in output order: the constant ``'1'`` (if enabled), the
          linear terms, then the nonlinear monomials of each order.
        """
        # Names of the current inputs come first, then one group per delay.
        if for_plot:
            linear_names = [f'x{i}_t' for i in range(self.num_in)]
        else:
            linear_names = [f'x{i}(t)' for i in range(self.num_in)]
        for di in range(1, self.delay):
            linear_names.extend([((f'x{i}_' + r'{t-%d}' % (di * self.stride))
                                  if for_plot else f'x{i}(t-{di * self.stride})')
                                 for i in range(self.num_in)])
        nonlinear_names = []
        for ids in self.comb_ids:
            for id_ in np.asarray(ids):
                # Collapse repeated components of a monomial into a power.
                uniques, counts = np.unique(id_, return_counts=True)
                nonlinear_names.append(" ".join(
                    "%s^%d" % (linear_names[ind], exp) if (exp != 1) else linear_names[ind]
                    for ind, exp in zip(uniques, counts)
                ))
        if for_plot:
            all_names = [f'${n}$' for n in linear_names] + [f'${n}$' for n in nonlinear_names]
        else:
            all_names = linear_names + nonlinear_names
        if self.constant:
            all_names = ['1'] + all_names
        return all_names
